{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright 2020 NVIDIA Corporation. All Rights Reserved.\n",
    "#\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "#     http://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License.\n",
    "# =============================================================================="
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "!sync; echo 3 > /proc/sys/vm/drop_caches"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from time import time\n",
    "import re\n",
    "import glob\n",
    "import warnings\n",
    "\n",
    "# tools for data preproc/loading\n",
    "import torch\n",
    "import rmm\n",
    "import nvtabular as nvt\n",
    "from nvtabular.ops import Normalize,  Categorify,  LogOp, FillMissing, Clip, get_embedding_sizes\n",
    "from nvtabular.loader.torch import TorchAsyncItr, DLDataLoader\n",
    "from nvtabular.utils import device_mem_size\n",
    "\n",
    "# tools for training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# define some information about where to get our data\n",
    "INPUT_DATA_DIR = os.environ.get('INPUT_DATA_DIR', '/raid/criteo/tests/crit_int_pq')\n",
    "OUTPUT_DATA_DIR = os.environ.get('OUTPUT_DATA_DIR', '/raid/criteo/tests/test_dask') # where we'll save our procesed data to\n",
    "BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 32768))\n",
    "AMP = os.environ.get(\"AMP\", \"true\") \n",
    "AMP = True if AMP.lower() in \"true\" else False\n",
    "PARTS_PER_CHUNK = int(os.environ.get('PARTS_PER_CHUNK', 2))\n",
    "SHUFFLE = os.environ.get(\"SHUFFLE\", False)\n",
    "NUM_TRAIN_DAYS = 23 # number of days worth of data to use for training, the rest will be used for validation\n",
    "\n",
    "# define our dataset schema\n",
    "CONTINUOUS_COLUMNS = ['I' + str(x) for x in range(1,14)]\n",
    "CATEGORICAL_COLUMNS =  ['C' + str(x) for x in range(1,27)]\n",
    "LABEL_COLUMNS = ['label']\n",
    "COLUMNS = CONTINUOUS_COLUMNS + CATEGORICAL_COLUMNS + LABEL_COLUMNS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "output_train_dir = os.path.join(OUTPUT_DATA_DIR, 'train/')\n",
    "output_valid_dir = os.path.join(OUTPUT_DATA_DIR, 'valid/')\n",
    "! mkdir -p $output_train_dir\n",
    "! mkdir -p $output_valid_dir"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "rmm.reinitialize(pool_allocator=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<pyarrow._parquet.FileMetaData object at 0x7f926cc81d10>\n",
      "  created_by: \n",
      "  num_columns: 40\n",
      "  num_rows: 20970668\n",
      "  num_row_groups: 154\n",
      "  format_version: 1.0\n",
      "  serialized_size: 403853 <pyarrow._parquet.FileMetaData object at 0x7f926cc81770>\n",
      "  created_by: \n",
      "  num_columns: 40\n",
      "  num_rows: 17836814\n",
      "  num_row_groups: 37\n",
      "  format_version: 1.0\n",
      "  serialized_size: 101741\n"
     ]
    }
   ],
   "source": [
    "train_paths = glob.glob(os.path.join(output_train_dir, \"*.parquet\"))\n",
    "valid_paths = glob.glob(os.path.join(output_valid_dir, \"*.parquet\"))\n",
    "\n",
    "import pyarrow as pa\n",
    "train_stats = pa.parquet.read_metadata(train_paths[0])\n",
    "valid_stats = pa.parquet.read_metadata(valid_paths[0])\n",
    "print(train_stats, valid_stats)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nvtabular.loader.torch import IterDL\n",
    "train_loader = IterDL(train_paths[:1], batch_size=BATCH_SIZE, shuffle=SHUFFLE)\n",
    "valid_loader = IterDL(valid_paths[:1], batch_size=BATCH_SIZE, shuffle=SHUFFLE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "EMBEDDING_DROPOUT_RATE = 0.04\n",
    "DROPOUT_RATES = [0.001, 0.01]\n",
    "HIDDEN_DIMS = [1000, 500]\n",
    "LEARNING_RATE = 0.001\n",
    "EPOCHS = 1\n",
    "embeddings = {'C1': (7599500, 16),\n",
    " 'C10': (5345303, 16),\n",
    " 'C11': (561810, 16),\n",
    " 'C12': (242827, 16),\n",
    " 'C13': (11, 6),\n",
    " 'C14': (2209, 16),\n",
    " 'C15': (10616, 16),\n",
    " 'C16': (100, 16),\n",
    " 'C17': (4, 3),\n",
    " 'C18': (968, 16),\n",
    " 'C19': (15, 7),\n",
    " 'C2': (33521, 16),\n",
    " 'C20': (7838519, 16),\n",
    " 'C21': (2580502, 16),\n",
    " 'C22': (6878028, 16),\n",
    " 'C23': (298771, 16),\n",
    " 'C24': (11951, 16),\n",
    " 'C25': (97, 16),\n",
    " 'C26': (35, 12),\n",
    " 'C3': (17022, 16),\n",
    " 'C4': (7339, 16),\n",
    " 'C5': (20046, 16),\n",
    " 'C6': (4, 3),\n",
    " 'C7': (7068, 16),\n",
    " 'C8': (1377, 16),\n",
    " 'C9': (63, 16)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/envs/rapids/lib/python3.7/site-packages/torch/cuda/__init__.py:104: UserWarning: \n",
      "Tesla V100-SXM2-32GB with CUDA capability sm_70 is not compatible with the current PyTorch installation.\n",
      "The current PyTorch install supports CUDA capabilities sm_60.\n",
      "If you want to use the Tesla V100-SXM2-32GB GPU with PyTorch, please check the instructions at https://pytorch.org/get-started/locally/\n",
      "\n",
      "  warnings.warn(incompatible_device_warn.format(device_name, capability, \" \".join(arch_list), device_name))\n"
     ]
    }
   ],
   "source": [
    "from nvtabular.framework_utils.torch.models import Model\n",
    "from nvtabular.framework_utils.torch.utils import process_epoch\n",
    "model = Model(\n",
    "    embedding_table_shapes=embeddings,\n",
    "    num_continuous=len(CONTINUOUS_COLUMNS),\n",
    "    emb_dropout=EMBEDDING_DROPOUT_RATE,\n",
    "    layer_hidden_dims=HIDDEN_DIMS,\n",
    "    layer_dropout_rates=DROPOUT_RATES,\n",
    ").to('cuda')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)\n",
    "def rmspe_func(y_pred, y):\n",
    "    \"Return y_pred and y to non-log space and compute RMSPE\"\n",
    "    y_pred, y = torch.exp(y_pred) - 1, torch.exp(y) - 1\n",
    "    pct_var = (y_pred - y) / y\n",
    "    return (pct_var**2).mean().pow(0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def batch_transform(batch, cont_cols=CONTINUOUS_COLUMNS, cat_cols=CATEGORICAL_COLUMNS, label_cols=LABEL_COLUMNS):\n",
    "    x_cat = torch.tensor(batch[sorted(cat_cols)].values).type(torch.LongTensor).cuda() \n",
    "    x_cont = torch.tensor(batch[cont_cols].values).type(torch.FloatTensor).cuda()\n",
    "    y = torch.tensor(batch[label_cols[0]].values).type(torch.FloatTensor).cuda()\n",
    "    return x_cat, x_cont, y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "ename": "RuntimeError",
     "evalue": "CUDA error: no kernel image is available for execution on the device",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-13-a9be539c4da9>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      6\u001b[0m                                           \u001b[0moptimizer\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      7\u001b[0m                                           \u001b[0mtransform\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbatch_transform\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 8\u001b[0;31m                                           \u001b[0mamp\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mAMP\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      9\u001b[0m                                          )\n\u001b[1;32m     10\u001b[0m     \u001b[0mtrain_rmspe\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrmspe_func\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/jp_docs/nvtabular/nvtabular/framework_utils/torch/utils.py\u001b[0m in \u001b[0;36mprocess_epoch\u001b[0;34m(dataloader, model, train, optimizer, loss_func, transform, amp, device)\u001b[0m\n\u001b[1;32m     61\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mamp\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     62\u001b[0m                 \u001b[0;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mamp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautocast\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 63\u001b[0;31m                     \u001b[0my_pred\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx_cat\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx_cont\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     64\u001b[0m                     \u001b[0my_pred_list\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     65\u001b[0m                     \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mloss_func\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/envs/rapids/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    725\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    726\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    728\u001b[0m         for hook in itertools.chain(\n\u001b[1;32m    729\u001b[0m                 \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/jp_docs/nvtabular/nvtabular/framework_utils/torch/models.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x_cat, x_cont)\u001b[0m\n\u001b[1;32m     72\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     73\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx_cat\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx_cont\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 74\u001b[0;31m         \u001b[0mx_cat\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minitial_cat_layer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx_cat\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     75\u001b[0m         \u001b[0mx_cont\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minitial_cont_layer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx_cont\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     76\u001b[0m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mx_cat\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx_cont\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/envs/rapids/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    725\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    726\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    728\u001b[0m         for hook in itertools.chain(\n\u001b[1;32m    729\u001b[0m                 \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/jp_docs/nvtabular/nvtabular/framework_utils/torch/layers/embeddings.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m     45\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     46\u001b[0m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mlayer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlayer\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0membedding_layers\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 47\u001b[0;31m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     48\u001b[0m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdropout\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     49\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mRuntimeError\u001b[0m: CUDA error: no kernel image is available for execution on the device"
     ]
    }
   ],
   "source": [
    "for epoch in range(EPOCHS):\n",
    "    start_train=time()\n",
    "    train_loss, y_pred, y = process_epoch(train_loader, \n",
    "                                          model, \n",
    "                                          train=True, \n",
    "                                          optimizer=optimizer,\n",
    "                                          transform=batch_transform,\n",
    "                                          amp=AMP,\n",
    "                                         )\n",
    "    train_rmspe = rmspe_func(y_pred, y)\n",
    "    train_time=time() - start_train\n",
    "    y_pred = None\n",
    "    y = None\n",
    "    start_valid=time()\n",
    "    valid_loss, y_pred, y = process_epoch(valid_loader,\n",
    "                                          model, \n",
    "                                          train=False, \n",
    "                                          optimizer=optimizer,\n",
    "                                          transform=batch_transform,\n",
    "                                          amp=AMP,\n",
    "                                         )\n",
    "    valid_rmspe = rmspe_func(y_pred, y)\n",
    "    valid_time = time() - start_valid\n",
    "    y_pred = None\n",
    "    y = None\n",
    "    print(f\"Train:{train_time} + Valid:{valid_time} = EpochTotal:{train_time + valid_time}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "file_extension": ".py",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.8"
  },
  "mimetype": "text/x-python",
  "name": "python",
  "npconvert_exporter": "python",
  "pygments_lexer": "ipython3",
  "version": 3
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
